from pathlib import Path

import numpy as np
from PIL import Image
from scipy.linalg import svd

def decompose_image_channel(channel, ranks):
    """
    Decompose a single image channel into frequency bands using its SVD.

    The channel is factorised as ``U @ diag(s) @ Vt`` (thin SVD). The singular
    values are split at the indices given in ``ranks``, plus a final split at
    ``min(height, width)``. Each slice ``[prev, rank)`` is reconstructed as its
    own band. When ``ranks`` is ascending, the slices partition the singular
    values, so the bands sum back to the original channel up to floating-point
    error.

    :param channel: 2D array representation of an image channel (any height/width).
    :param ranks: Split indices as a list, tuple or 1D integer array. Assumes
        ranks are sorted in ascending order. Indices beyond
        ``min(height, width)`` are clamped by slicing, so later bands may be
        all zeros.
    :return: List of ``len(ranks) + 1`` arrays, each with the same shape as
        ``channel``, corresponding to each frequency band.
    """
    # Perform a thin SVD on the channel: U is (h, k), s is (k,), Vt is (k, w)
    # with k = min(h, w).
    U, s, Vt = svd(channel, full_matrices=False)
    k = s.shape[0]

    decomposed_channels = []
    prev_rank = 0
    # Append k so the final band covers all remaining singular values.
    # list() accepts tuples/arrays. With a NumPy array, `ranks + [k]` would
    # broadcast-add k instead of appending it.
    for rank in list(ranks) + [k]:
        band = slice(prev_rank, rank)
        # Scale the selected columns of U by their singular values and project
        # back through the matching rows of Vt. This conforms for any (h, w).
        # Building an (h, k) diagonal matrix only works when h <= w.
        channel_k = (U[:, band] * s[band]) @ Vt[band, :]
        decomposed_channels.append(channel_k)
        prev_rank = rank

    return decomposed_channels

def combine_frequency_bands(bands):
    """
    Reconstruct an image channel by adding its frequency bands together.

    :param bands: Sequence of equally-shaped 2D arrays, one per frequency band.
    :return: Element-wise sum of all bands as a single 2D array.
    """
    # Stack the bands along a new leading axis, then collapse that axis by summation.
    stacked = np.asarray(bands)
    return stacked.sum(axis=0)

def _to_uint8_image(channel):
    """Convert a float channel to an 8-bit PIL image (round to nearest, clip to [0, 255])."""
    # Round before casting: astype() truncates, so SVD reconstructions such as
    # 254.9999 would otherwise become 254 instead of 255.
    return Image.fromarray(np.rint(channel).clip(0, 255).astype(np.uint8))


def _merge_rgb(channels):
    """Merge three float channels, in R, G, B order, into one RGB PIL image."""
    return Image.merge('RGB', [_to_uint8_image(c) for c in channels])


def save_decomposed_and_combined_images(image_path, ranks, output_dir='.'):
    """
    Split an image into SVD frequency bands per colour channel and save them as PNGs.

    Each of the R, G and B channels is split by ``decompose_image_channel``.
    Band ``i`` of the three channels is merged into ``frequency_band_{i}.png``.
    The sum of all bands (via ``combine_frequency_bands``) is saved as
    ``combined_image.png``. Pixel values outside [0, 255] are clipped when
    saving. Detail bands usually contain signed values, so their negative
    parts are lost in the PNGs.

    :param image_path: Path of the image to load. Non-RGB images (grayscale,
        palette, RGBA, ...) are converted to RGB first; any alpha is dropped.
    :param ranks: Ascending singular-value indices at which to split the bands.
    :param output_dir: Directory to write the PNG files into (default: the
        current working directory).
    """
    # Load the image. convert() guarantees exactly three bands, so split()
    # below cannot fail for non-RGB modes. The with-block closes the file handle.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')

    # Convert each colour channel to a float matrix.
    channels = [np.array(band, dtype=float) for band in rgb_image.split()]

    # SVD-decompose each colour channel separately; bands_per_channel[c][i] is
    # band i of channel c.
    bands_per_channel = [decompose_image_channel(c, ranks) for c in channels]

    out = Path(output_dir)

    # Save one RGB image per frequency band.
    for i, band_rgb in enumerate(zip(*bands_per_channel)):
        _merge_rgb(band_rgb).save(out / f'frequency_band_{i}.png')

    # Recombine all bands per channel and save the reconstructed colour image.
    combined = [combine_frequency_bands(bands) for bands in bands_per_channel]
    _merge_rgb(combined).save(out / 'combined_image.png')

# Run the demo only when executed as a script. Importing this module (e.g. to
# reuse decompose_image_channel) then does not read '00000.png' or write PNG
# files as a side effect.
if __name__ == "__main__":
    # Call the function with the image path and the ranks used to split the
    # frequency bands.
    save_decomposed_and_combined_images('00000.png', [8, 30])